import pickle
import numpy as np

def load_cifar10_batch(file):
    """Load one pickled CIFAR-10 batch file into image and label containers.

    Args:
        file: Path to a pickled batch file whose payload is a mapping with
            ``b'data'`` (an array reshapeable to ``(N, 3, 32, 32)``) and
            ``b'labels'`` entries.

    Returns:
        A tuple ``(X, Y)``: ``X`` is a float array of shape
        ``(N, 32, 32, 3)`` and ``Y`` is the object stored under
        ``b'labels'``, returned unchanged.
    """
    # NOTE: pickle.load can execute arbitrary code while deserializing;
    # only call this on batch files obtained from a trusted source.
    with open(file, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')

    # The reshape reads every row as three 32x32 planes; -1 lets NumPy infer
    # the image count instead of requiring exactly 10000 rows. The transpose
    # then moves the plane axis last, giving (N, 32, 32, 3).
    X = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = batch[b'labels']
    return X, Y

# Load the first CIFAR-10 training batch.
X, Y = load_cifar10_batch(
    '../datasets/cifar-10-batches-py/data_batch_1'
)

# Print the image array shape, then the label count, one per line.
# For a full 10000-image batch this prints:
#   (10000, 32, 32, 3)
#   10000
print(X.shape, len(Y), sep="\n")